#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015-2020 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)


import collections
import math
import os
import pandas as pd
import backtrader as bt
from backtrader.utils.py3 import items, iteritems

from . import NetAssetValue, PositionsValue, SQN, SharpeRatio, \
    TimeDrawDownHistory, AnnualReturn, MonthlyReturn, Trades, PeriodStats, Calmar, TimeReturn
from backtrader.utils.datehelper import *


class AnalysisReport(bt.Analyzer):
    """Bundle several sub-analyzers and optionally dump them as CSV reports.

    When ``csv`` is true, ``stop`` writes the following files into ``out``:

    - ``summary.csv``   -- one ``item``/``val`` row per headline statistic
    - ``positions.csv`` -- position value history (skipped if empty)
    - ``returns.csv``   -- yearly, monthly and per-period returns
    - ``nav.csv``       -- net asset value history (skipped if empty)
    - ``mdd.csv``       -- drawdown history (skipped if empty)
    - ``trades.csv``    -- trade history (skipped if empty)

    Params:
      - ``timeframe`` / ``compression``: forwarded to the NAV, period-stats,
        Sharpe ratio, time-return and drawdown-history sub-analyzers.
      - ``csv``: write the CSV reports when the run stops.
      - ``out``: output directory, created (with parents) if missing.
        ``None`` or ``""`` means the current working directory.
    """

    params = (
        ('timeframe', bt.TimeFrame.Days),
        ('compression', 1),
        ('csv', False),
        ('out', None)
    )

    def __init__(self):
        dtfcomp = dict(timeframe=self.p.timeframe,
                       compression=self.p.compression)

        # Value of the held positions (``cash=True`` also records cash)
        self._positions = PositionsValue(headers=True, cash=True)

        # Strategy net asset value
        self._nav = NetAssetValue(**dtfcomp)

        # Interval return statistics (average, stddev, best, worst, ...)
        self._stats = PeriodStats(**dtfcomp)

        # Annualized Sharpe ratio
        self._sharp = SharpeRatio(**dtfcomp, annualize=True)

        # System Quality Number (its analysis also carries the trade count)
        self._sqn = SQN()

        # Yearly returns
        self._yearlyreturn = AnnualReturn()

        # Monthly returns
        self._monthlyreturn = MonthlyReturn()

        # Per-period returns at timeframe/compression (daily by default)
        self._dailyreturn = TimeReturn(**dtfcomp)

        # Trade history
        self._trades_his = Trades(headers=True)

        # Drawdown history
        self._mdd_his = TimeDrawDownHistory(**dtfcomp)

        if self.p.csv and self.p.out:
            # makedirs also creates missing parents; exist_ok avoids the
            # exists-then-create race of a separate existence check.
            os.makedirs(self.p.out, exist_ok=True)

    def stop(self):
        """Write the CSV reports into ``out`` when ``csv`` is enabled."""
        super(AnalysisReport, self).stop()

        if not self.p.csv:
            return

        # Summary of headline statistics (always has rows)
        summary = self._build_summary()
        self._write_csv(
            "summary.csv",
            pd.DataFrame(columns=["item", "val"],
                         data=[[k, v] for k, v in summary.items()]))

        # Position history
        self._write_csv("positions.csv", self._positions.to_dataframe())

        # Yearly / monthly / per-period returns
        self._write_csv(
            "returns.csv",
            pd.DataFrame(columns=["date", "type", "ret"],
                         data=self._collect_returns()))

        # NAV history
        self._write_csv("nav.csv", self._nav.to_dataframe())

        # Drawdown history
        self._write_csv("mdd.csv", self._mdd_his.to_dataframe())

        # Trade history
        self._write_csv("trades.csv", self._trades_his.to_dataframe())

    def _build_summary(self):
        """Return the ordered headline statistics written to summary.csv."""
        navs = list(self._nav.get_analysis().values())
        # An empty NAV history or a zero starting value would otherwise raise
        # IndexError / ZeroDivisionError at the very end of a run.
        if navs and navs[0]:
            total_rets = navs[-1] / navs[0] - 1
        else:
            total_rets = float("nan")

        period_stats = self._stats.get_analysis()
        sharp_stats = self._sharp.get_analysis()
        mdd_stats = self._mdd_his.get_analysis()
        sqn_stats = self._sqn.get_analysis()

        summary = collections.OrderedDict()
        summary["total_rets"] = total_rets
        summary["average_rets"] = period_stats["average"]
        summary["stddev_rets"] = period_stats["stddev"]
        summary["positive_rets"] = period_stats["positive"]
        summary["negative_rets"] = period_stats["negative"]
        summary["best_rets"] = period_stats["best"]
        summary["worst_rets"] = period_stats["worst"]
        summary["sharp"] = sharp_stats["sharperatio"]
        summary["mdd"] = mdd_stats['maxdrawdown']
        summary["mddlen"] = mdd_stats['maxdrawdownperiod']
        summary["sqn"] = sqn_stats["sqn"]
        summary["trade_count"] = sqn_stats["trades"]
        return summary

    def _collect_returns(self):
        """Return ``[date, type, ret]`` rows: yearly, then monthly, then per-period."""
        rows = [[date, "year", v]
                for date, v in self._yearlyreturn.get_analysis().items()]
        rows.extend([date, "month", v]
                    for date, v in self._monthlyreturn.get_analysis().items())
        # Per-period keys are formatted as YYYYMMDD and labelled "day"
        # regardless of the configured timeframe.
        rows.extend([date2str(date, _format_str="%Y%m%d"), "day", v]
                    for date, v in self._dailyreturn.get_analysis().items())
        return rows

    def _write_csv(self, filename, df):
        """Write *df* (without index) to ``<out>/<filename>``; skip if empty."""
        if df is None or not len(df):
            return
        # ``out`` of None/"" resolves to a path relative to the working dir.
        df.to_csv(os.path.join(self.p.out or "", filename), index=False)









